feat: add support for list of messages as team task input and update Society of Mind Agent (#4500)

* feat: add support for list of messages as team task input
* Update society of mind agent to use the list input task
---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Arun Brahma
2024-12-15 11:18:17 +05:30
committed by GitHub
parent c7145156b1
commit 7c0bbf674f
16 changed files with 361 additions and 134 deletions

View File

@@ -1,10 +1,14 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args
from autogen_core import CancellationToken
from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..messages import (
AgentMessage,
ChatMessage,
TextMessage,
)
from ..state import BaseState
@@ -45,8 +49,9 @@ class BaseChatAgent(ChatAgent, ABC):
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
"""Handles incoming messages and returns a stream of messages and
and the final item is the response. The base implementation in :class:`BaseChatAgent`
simply calls :meth:`on_messages` and yields the messages in the response."""
and the final item is the response. The base implementation in
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
the messages in the response."""
response = await self.on_messages(messages, cancellation_token)
for inner_message in response.inner_messages or []:
yield inner_message
@@ -55,7 +60,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
@@ -69,7 +74,14 @@ class BaseChatAgent(ChatAgent, ABC):
text_msg = TextMessage(content=task, source="user")
input_messages.append(text_msg)
output_messages.append(text_msg)
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
else:
@@ -83,7 +95,7 @@ class BaseChatAgent(ChatAgent, ABC):
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the agent with the given task and return a stream of messages
@@ -99,7 +111,15 @@ class BaseChatAgent(ChatAgent, ABC):
input_messages.append(text_msg)
output_messages.append(text_msg)
yield text_msg
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
elif isinstance(task, list):
for msg in task:
if isinstance(msg, get_args(ChatMessage)[0]):
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, get_args(ChatMessage)[0]):
input_messages.append(task)
output_messages.append(task)
yield task

View File

@@ -1,10 +1,10 @@
from typing import AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from autogen_core import CancellationToken, Image
from autogen_core.models import ChatCompletionClient
from autogen_core.models._types import SystemMessage
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
from autogen_agentchat.base import Response
from autogen_agentchat.state import SocietyOfMindAgentState
from ..base import TaskResult, Team
from ..messages import (
@@ -32,6 +32,10 @@ class SocietyOfMindAgent(BaseChatAgent):
team (Team): The team of agents to use.
model_client (ChatCompletionClient): The model client to use for preparing responses.
description (str, optional): The description of the agent.
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.
Example:
@@ -39,35 +43,51 @@ class SocietyOfMindAgent(BaseChatAgent):
.. code-block:: python
import asyncio
from autogen_agentchat.ui import Console
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.conditions import TextMentionTermination
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
agent2 = AssistantAgent(
"assistant2",
model_client=model_client,
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
)
inner_termination = TextMentionTermination("APPROVE")
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.")
agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.")
outter_termination = MaxMessageTermination(10)
team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination)
agent3 = AssistantAgent(
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
)
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)
stream = team.run_stream(task="Tell me a one-liner joke.")
async for message in stream:
print(message)
stream = team.run_stream(task="Write a short story with a surprising ending.")
await Console(stream)
asyncio.run(main())
"""
DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
"""str: The default instruction to use when generating a response using the
inner team's messages. The instruction will be prepended to the inner team's
messages when generating a response using the model. It assumes the role of
'system'."""
DEFAULT_RESPONSE_PROMPT = (
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
)
"""str: The default response prompt to use when generating a response using
the inner team's messages. It assumes the role of 'system'."""
def __init__(
self,
name: str,
@@ -75,17 +95,13 @@ class SocietyOfMindAgent(BaseChatAgent):
model_client: ChatCompletionClient,
*,
description: str = "An agent that uses an inner team of agents to generate responses.",
task_prompt: str = "{transcript}\nContinue.",
response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.",
instruction: str = DEFAULT_INSTRUCTION,
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
) -> None:
super().__init__(name=name, description=description)
self._team = team
self._model_client = model_client
if "{transcript}" not in task_prompt:
raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.")
self._task_prompt = task_prompt
if "{transcript}" not in response_prompt:
raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.")
self._instruction = instruction
self._response_prompt = response_prompt
@property
@@ -104,33 +120,41 @@ class SocietyOfMindAgent(BaseChatAgent):
async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[AgentMessage | Response, None]:
# Build the context.
delta = list(messages)
task: str | None = None
if len(delta) > 0:
task = self._task_prompt.format(transcript=self._create_transcript(delta))
# Prepare the task for the team of agents.
task = list(messages)
# Run the team of agents.
result: TaskResult | None = None
inner_messages: List[AgentMessage] = []
count = 0
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
if isinstance(inner_msg, TaskResult):
result = inner_msg
else:
count += 1
if count <= len(task):
# Skip the task messages.
continue
yield inner_msg
inner_messages.append(inner_msg)
assert result is not None
if len(inner_messages) < 2:
# The first message is the task message so we need at least 2 messages.
if len(inner_messages) == 0:
yield Response(
chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages
)
else:
prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:]))
completion = await self._model_client.create(
messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token
# Generate a response using the model client.
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
llm_messages.extend(
[
UserMessage(content=message.content, source=message.source)
for message in inner_messages
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
]
)
llm_messages.append(SystemMessage(content=self._response_prompt))
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
assert isinstance(completion.content, str)
yield Response(
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
@@ -143,17 +167,11 @@ class SocietyOfMindAgent(BaseChatAgent):
async def on_reset(self, cancellation_token: CancellationToken) -> None:
await self._team.reset()
def _create_transcript(self, messages: Sequence[AgentMessage]) -> str:
transcript = ""
for message in messages:
if isinstance(message, TextMessage | StopMessage | HandoffMessage):
transcript += f"{message.source}: {message.content}\n"
elif isinstance(message, MultiModalMessage):
for content in message.content:
if isinstance(content, Image):
transcript += f"{message.source}: [Image]\n"
else:
transcript += f"{message.source}: {content}\n"
else:
raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}")
return transcript
async def save_state(self) -> Mapping[str, Any]:
team_state = await self._team.save_state()
state = SocietyOfMindAgentState(inner_team_state=team_state)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
await self._team.load_state(society_of_mind_state.inner_team_state)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, Protocol, Sequence
from typing import AsyncGenerator, List, Protocol, Sequence
from autogen_core import CancellationToken
@@ -23,7 +23,7 @@ class TaskRunner(Protocol):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the task and return the result.
@@ -36,7 +36,7 @@ class TaskRunner(Protocol):
def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the task and produces a stream of messages and the final result

View File

@@ -8,6 +8,7 @@ from ._states import (
MagenticOneOrchestratorState,
RoundRobinManagerState,
SelectorManagerState,
SocietyOfMindAgentState,
SwarmManagerState,
TeamState,
)
@@ -22,4 +23,5 @@ __all__ = [
"SwarmManagerState",
"MagenticOneOrchestratorState",
"TeamState",
"SocietyOfMindAgentState",
]

View File

@@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState):
n_rounds: int = Field(default=0)
n_stalls: int = Field(default=0)
type: str = Field(default="MagenticOneOrchestratorState")
class SocietyOfMindAgentState(BaseState):
"""State for a Society of Mind agent."""
inner_team_state: Mapping[str, Any] = Field(default_factory=dict)
type: str = Field(default="SocietyOfMindAgentState")

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, List, Mapping
from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args
from autogen_core import (
AgentId,
@@ -19,7 +19,7 @@ from autogen_core._closure_agent import ClosureContext
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ...messages import AgentMessage, ChatMessage, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@@ -146,11 +146,18 @@ class BaseGroupChat(Team, ABC):
message: GroupChatStart | GroupChatMessage | GroupChatTermination,
ctx: MessageContext,
) -> None:
event_logger.info(message.message)
if isinstance(message, GroupChatTermination):
"""Collect output messages from the group chat."""
if isinstance(message, GroupChatStart):
if message.messages is not None:
for msg in message.messages:
event_logger.info(msg)
await self._output_message_queue.put(msg)
elif isinstance(message, GroupChatMessage):
event_logger.info(message.message)
await self._output_message_queue.put(message.message)
elif isinstance(message, GroupChatTermination):
event_logger.info(message.message)
self._stop_reason = message.message.content
return
await self._output_message_queue.put(message.message)
await ClosureAgent.register_closure(
runtime,
@@ -165,7 +172,7 @@ class BaseGroupChat(Team, ABC):
async def run(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> TaskResult:
"""Run the team and return the result. The base implementation uses
@@ -173,7 +180,7 @@ class BaseGroupChat(Team, ABC):
Once the team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | None): The task to run the team with.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
@@ -264,7 +271,7 @@ class BaseGroupChat(Team, ABC):
async def run_stream(
self,
*,
task: str | ChatMessage | None = None,
task: str | ChatMessage | List[ChatMessage] | None = None,
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
@@ -272,7 +279,7 @@ class BaseGroupChat(Team, ABC):
team is stopped, the termination condition is reset.
Args:
task (str | ChatMessage | None): The task to run the team with.
task (str | ChatMessage | List[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
@@ -355,16 +362,20 @@ class BaseGroupChat(Team, ABC):
"""
# Create the first chat message if the task is a string or a chat message.
first_chat_message: ChatMessage | None = None
# Create the messages list if the task is a string or a chat message.
messages: List[ChatMessage] | None = None
if task is None:
pass
elif isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
first_chat_message = task
else:
raise ValueError(f"Invalid task type: {type(task)}")
messages = [TextMessage(content=task, source="user")]
elif isinstance(task, get_args(ChatMessage)[0]):
messages = [task] # type: ignore
elif isinstance(task, list):
if not task:
raise ValueError("Task list cannot be empty")
if not all(isinstance(msg, get_args(ChatMessage)[0]) for msg in task):
raise ValueError("All messages in task list must be valid ChatMessage types")
messages = task
if self._is_running:
raise ValueError("The team is already running, it cannot run again until it is stopped.")
@@ -389,7 +400,7 @@ class BaseGroupChat(Team, ABC):
# The group chat manager will start the group chat by relaying the message to the participants
# and the closure agent.
await self._runtime.send_message(
GroupChatStart(message=first_chat_message),
GroupChatStart(messages=messages),
recipient=AgentId(type=self._group_chat_manager_topic_type, key=self._team_id),
cancellation_token=cancellation_token,
)

View File

@@ -70,24 +70,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
# Stop the group chat.
return
# Validate the group state given the start message.
await self.validate_group_state(message.message)
# Validate the group state given the start messages
await self.validate_group_state(message.messages)
if message.message is not None:
# Log the start message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Relay the start message to the participants.
if message.messages is not None:
# Log all messages at once
await self.publish_message(
message, topic_id=DefaultTopicId(type=self._group_topic_type), cancellation_token=ctx.cancellation_token
GroupChatStart(messages=message.messages), topic_id=DefaultTopicId(type=self._output_topic_type)
)
# Append the user message to the message thread.
self._message_thread.append(message.message)
# Relay all messages at once to participants
await self.publish_message(
GroupChatStart(messages=message.messages),
topic_id=DefaultTopicId(type=self._group_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Check if the conversation should be terminated.
# Append all messages to thread
self._message_thread.extend(message.messages)
# Check termination condition after processing all messages
if self._termination_condition is not None:
stop_message = await self._termination_condition([message.message])
stop_message = await self._termination_condition(message.messages)
if stop_message is not None:
await self.publish_message(
GroupChatTermination(message=stop_message),
@@ -97,7 +101,7 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
await self._termination_condition.reset()
return
# Select a speaker to start the conversation.
# Select a speaker to start/continue the conversation
speaker_topic_type_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_topic_type_future)
@@ -166,8 +170,13 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
await self.reset()
@abstractmethod
async def validate_group_state(self, message: ChatMessage | None) -> None:
"""Validate the state of the group chat given the start message. This is executed when the group chat manager receives a GroupChatStart event."""
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
"""Validate the state of the group chat given the start messages.
This is executed when the group chat manager receives a GroupChatStart event.
Args:
messages: A list of chat messages to validate, or None if no messages are provided.
"""
...
@abstractmethod

View File

@@ -30,8 +30,8 @@ class ChatAgentContainer(SequentialRoutedAgent):
@event
async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> None:
"""Handle a start event by appending the content to the buffer."""
if message.message is not None:
self._message_buffer.append(message.message)
if message.messages is not None:
self._message_buffer.extend(message.messages)
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:

View File

@@ -1,3 +1,5 @@
from typing import List
from pydantic import BaseModel
from ...base import Response
@@ -7,8 +9,8 @@ from ...messages import AgentMessage, ChatMessage, StopMessage
class GroupChatStart(BaseModel):
"""A request to start a group chat."""
message: ChatMessage | None = None
"""An optional user message to start the group chat."""
messages: List[ChatMessage] | None = None
"""An optional list of messages to start the group chat."""
class GroupChatAgentResponse(BaseModel):

View File

@@ -126,17 +126,18 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
)
# Stop the group chat.
return
assert message is not None and message.message is not None
assert message is not None and message.messages is not None
# Validate the group state given the start message.
await self.validate_group_state(message.message)
# Validate the group state given all the messages.
await self.validate_group_state(message.messages)
# Log the start message.
# Log the message.
await self.publish_message(message, topic_id=DefaultTopicId(type=self._output_topic_type))
# Outer Loop for first time
# Create the initial task ledger
#################################
self._task = self._content_to_str(message.message.content)
# Combine all message contents for task
self._task = " ".join([self._content_to_str(msg.content) for msg in message.messages])
planning_conversation: List[LLMMessage] = []
# 1. GATHER FACTS
@@ -184,7 +185,7 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
return
await self._orchestrate_step(ctx.cancellation_token)
async def validate_group_state(self, message: ChatMessage | None) -> None:
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass
async def save_state(self) -> Mapping[str, Any]:

View File

@@ -29,7 +29,7 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
)
self._next_speaker_index = 0
async def validate_group_state(self, message: ChatMessage | None) -> None:
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass
async def reset(self) -> None:

View File

@@ -54,7 +54,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._allow_repeated_speaker = allow_repeated_speaker
self._selector_func = selector_func
async def validate_group_state(self, message: ChatMessage | None) -> None:
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
pass
async def reset(self) -> None:

View File

@@ -29,16 +29,19 @@ class SwarmGroupChatManager(BaseGroupChatManager):
)
self._current_speaker = participant_topic_types[0]
async def validate_group_state(self, message: ChatMessage | None) -> None:
"""Validate the start message for the group chat."""
# Check if the start message is a handoff message.
if isinstance(message, HandoffMessage):
if message.target not in self._participant_topic_types:
raise ValueError(
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
)
return
async def validate_group_state(self, messages: List[ChatMessage] | None) -> None:
"""Validate the start messages for the group chat."""
# Check if any of the start messages is a handoff message.
if messages:
for message in messages:
if isinstance(message, HandoffMessage):
if message.target not in self._participant_topic_types:
raise ValueError(
f"The target {message.target} is not one of the participants {self._participant_topic_types}. "
"If you are resuming Swarm with a new HandoffMessage make sure to set the target to a valid participant as the target."
)
return
# Check if there is a handoff message in the thread that is not targeting a valid participant.
for existing_message in reversed(self._message_thread):
if isinstance(existing_message, HandoffMessage):

View File

@@ -8,6 +8,7 @@ from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Handoff, TaskResult
from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
MultiModalMessage,
TextMessage,
@@ -21,7 +22,10 @@ from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)
from openai.types.completion_usage import CompletionUsage
from utils import FileLogHandler
@@ -33,14 +37,14 @@ logger.addHandler(FileLogHandler("test_assistant_agent.log"))
class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self._curr_index = 0
self.curr_index = 0
async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self._curr_index]
self._curr_index += 1
completion = self._saved_chat_completions[self.curr_index]
self.curr_index += 1
return completion
@@ -90,7 +94,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="pass", role="assistant"),
)
],
created=0,
model=model,
@@ -101,7 +109,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
id="id2",
choices=[
Choice(
finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
)
],
created=0,
@@ -115,7 +125,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
result = await agent.run(task="task")
@@ -133,14 +147,14 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[3].models_usage is None
# Test streaming.
mock._curr_index = 0 # pyright: ignore
mock.curr_index = 0 # Reset the mock
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
index += 1
# Test state saving and loading.
state = await agent.save_state()
@@ -234,7 +248,7 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
assert result.messages[3].models_usage.prompt_tokens == 10
# Test streaming.
mock._curr_index = 0 # pyright: ignore
mock.curr_index = 0 # pyright: ignore
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -248,7 +262,11 @@ async def test_run_with_tools_and_reflection(monkeypatch: pytest.MonkeyPatch) ->
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
await agent2.load_state(state)
state2 = await agent2.save_state()
@@ -293,7 +311,11 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
tool_use_agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
handoffs=[handoff],
)
assert HandoffMessage in tool_use_agent.produced_message_types
@@ -313,7 +335,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[3].models_usage is None
# Test streaming.
mock._curr_index = 0 # pyright: ignore
mock.curr_index = 0 # pyright: ignore
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
@@ -330,7 +352,11 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Hello", role="assistant"),
)
],
created=0,
model=model,
@@ -340,7 +366,10 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(name="assistant", model_client=OpenAIChatCompletionClient(model=model, api_key=""))
agent = AssistantAgent(
name="assistant",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
)
# Generate a random base64 image.
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
result = await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
@@ -351,14 +380,24 @@ async def test_multi_modal_task(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_invalid_model_capabilities() -> None:
model = "random-model"
model_client = OpenAIChatCompletionClient(
model=model, api_key="", model_capabilities={"vision": False, "function_calling": False, "json_output": False}
model=model,
api_key="",
model_capabilities={
"vision": False,
"function_calling": False,
"json_output": False,
},
)
with pytest.raises(ValueError):
agent = AssistantAgent(
name="assistant",
model_client=model_client,
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
tools=[
_pass_function,
_fail_function,
FunctionTool(_echo_function, description="Echo"),
],
)
with pytest.raises(ValueError):
@@ -369,3 +408,62 @@ async def test_invalid_model_capabilities() -> None:
# Generate a random base64 image.
img_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC"
await agent.run(task=MultiModalMessage(source="user", content=["Test", Image.from_base64(img_base64)]))
@pytest.mark.asyncio
async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(content="Response to message 1", role="assistant"),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
agent = AssistantAgent(
"test_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
)
# Create a list of chat messages
messages: List[ChatMessage] = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
]
# Test run method with list of messages
result = await agent.run(task=messages)
assert len(result.messages) == 3 # 2 input messages + 1 response message
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].content == "Message 1"
assert result.messages[0].source == "user"
assert isinstance(result.messages[1], TextMessage)
assert result.messages[1].content == "Message 2"
assert result.messages[1].source == "user"
assert isinstance(result.messages[2], TextMessage)
assert result.messages[2].content == "Response to message 1"
assert result.messages[2].source == "test_agent"
assert result.messages[2].models_usage is not None
assert result.messages[2].models_usage.completion_tokens == 5
assert result.messages[2].models_usage.prompt_tokens == 10
# Test run_stream method with list of messages
mock.curr_index = 0 # Reset mock index using public attribute
index = 0
async for message in agent.run_stream(task=messages):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1

View File

@@ -1025,3 +1025,48 @@ async def test_swarm_with_handoff_termination() -> None:
assert result.messages[1].content == "Transferred to second_agent."
assert result.messages[2].content == "Transferred to third_agent."
assert result.messages[3].content == "Transferred to non_existing_agent."
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_message_list() -> None:
# Create a simple team with echo agents
agent1 = _EchoAgent("Agent1", "First agent")
agent2 = _EchoAgent("Agent2", "Second agent")
termination = MaxMessageTermination(4) # Stop after 4 messages
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)
# Create a list of messages
messages: List[ChatMessage] = [
TextMessage(content="Message 1", source="user"),
TextMessage(content="Message 2", source="user"),
TextMessage(content="Message 3", source="user"),
]
# Run the team with the message list
result = await team.run(task=messages)
# Verify the messages were processed in order
assert len(result.messages) == 4 # Initial messages + echo until termination
assert result.messages[0].content == "Message 1" # First message
assert result.messages[1].content == "Message 2" # Second message
assert result.messages[2].content == "Message 3" # Third message
assert result.messages[3].content == "Message 1" # Echo from first agent
assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4"
# Test with streaming
await team.reset()
index = 0
async for message in team.run_stream(task=messages):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
# Test with invalid message list
with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input
# Test with empty message list
with pytest.raises(ValueError, match="Task list cannot be empty"):
await team.run(task=[])

View File

@@ -72,9 +72,20 @@ async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
response = await society_of_mind_agent.run(task="Count to 10.")
assert len(response.messages) == 5
assert len(response.messages) == 4
assert response.messages[0].source == "user"
assert response.messages[1].source == "user"
assert response.messages[2].source == "assistant1"
assert response.messages[3].source == "assistant2"
assert response.messages[4].source == "society_of_mind"
assert response.messages[1].source == "assistant1"
assert response.messages[2].source == "assistant2"
assert response.messages[3].source == "society_of_mind"
# Test save and load state.
state = await society_of_mind_agent.save_state()
assert state is not None
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
society_of_mind_agent2 = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
await society_of_mind_agent2.load_state(state)
state2 = await society_of_mind_agent2.save_state()
assert state == state2